package com.luohuasheng.spring.boot.starter.core;

import com.luohuasheng.spring.boot.starter.annotation.Lock;
import com.luohuasheng.spring.boot.starter.utils.LockException;
import org.apache.commons.lang.StringUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.*;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.expression.MethodBasedEvaluationContext;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

import static com.luohuasheng.spring.boot.starter.annotation.LockConstant.REDIS_LOCK_PREFIX;


/**
 * AOP aspect for methods annotated with {@link Lock}.
 *
 * <p>Before the annotated method runs, a key is built from the method name and the
 * SpEL expressions in {@code Lock#lockKey()}, prefixed with {@code REDIS_LOCK_PREFIX},
 * and passed to {@link LockCache#lock}. If that call returns {@code false} a
 * {@link LockException} carrying {@code Lock#errorMsg()} is thrown.
 *
 * <p>Every successful acquisition is recorded per thread and released exactly once by
 * whichever after-advice runs first. An advice that runs for an invocation that never
 * acquired anything releases nothing.
 *
 * @author zl
 */
@Aspect
@Component
public class LockAspect {

    /**
     * Locks acquired by the current thread that have not been released yet, in
     * acquisition order. The thread-local is removed as soon as the list becomes empty
     * so that pooled threads do not carry stale entries into later requests.
     */
    private static final ThreadLocal<List<HeldLock>> HELD_LOCKS = new ThreadLocal<>();

    private final ParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();

    private final ExpressionParser parser = new SpelExpressionParser();

    @Autowired
    private LockCache lockCache;

    /**
     * Pointcut matching every method annotated with {@link Lock}.
     */
    @Pointcut("@annotation(com.luohuasheng.spring.boot.starter.annotation.Lock)")
    public void lockFilter() {
        // Pointcut signature only; intentionally empty.
    }

    /**
     * Acquires the lock before the annotated method runs.
     *
     * @param point the intercepted invocation
     * @throws LockException if {@code lockCache.lock(key)} returns {@code false}
     * @throws IllegalStateException if no {@link Lock} annotation can be found on the method
     */
    @Before("lockFilter()")
    public void dataFilter(JoinPoint point) throws Throwable {
        Method method = getMethod(point);
        Lock lock = findLock(point, method);
        String key = REDIS_LOCK_PREFIX + getKeyName(method, point.getArgs(), lock);
        if (!lockCache.lock(key)) {
            throw new LockException(lock.errorMsg());
        }
        // Record ownership only after the lock was really obtained, so the release
        // advices can tell "we hold it" apart from "somebody else holds it".
        rememberLock(point, key);
    }

    /**
     * Releases the lock after the annotated method returned normally.
     *
     * @param point the intercepted invocation
     */
    @AfterReturning("lockFilter()")
    public void afterFilter(JoinPoint point) {
        releaseLock(point);
    }

    /**
     * Releases the lock after the annotated method finished, whatever the outcome.
     *
     * @param point the intercepted invocation
     */
    @After("lockFilter()")
    public void after(JoinPoint point) {
        releaseLock(point);
    }

    /**
     * Releases the lock after the annotated method threw.
     *
     * <p>No special case is needed for {@link LockException}: a failed acquisition
     * never records a held lock, so there is nothing to release for it.
     *
     * @param point the intercepted invocation
     * @param ex    the exception that was thrown (bound by the advice, not inspected)
     */
    @AfterThrowing(value = "lockFilter()", throwing = "ex")
    public void throwFilter(JoinPoint point, Throwable ex) {
        releaseLock(point);
    }

    /**
     * Releases the lock recorded for this invocation, if it is still held.
     *
     * <p>Safe to call several times per invocation; only the first call unlocks.
     */
    private void releaseLock(JoinPoint point) {
        List<HeldLock> held = HELD_LOCKS.get();
        if (held == null || held.isEmpty()) {
            // This thread holds nothing: never unlock a key owned by another caller.
            HELD_LOCKS.remove();
            return;
        }
        // NOTE(review): assumes the framework hands the same JoinPoint instance to every
        // advice of one invocation - confirm. Matching on it releases the key that was
        // actually locked even if the method changed the arguments the key is built from.
        int index = lastIndexOfOwner(held, point);
        if (index < 0) {
            // Fallback when the JoinPoint instance does not match: recompute the key.
            Method method = getMethod(point);
            String key = REDIS_LOCK_PREFIX + getKeyName(method, point.getArgs(), findLock(point, method));
            index = lastIndexOfKey(held, key);
        }
        if (index < 0) {
            return;
        }
        String key = held.remove(index).key;
        if (held.isEmpty()) {
            HELD_LOCKS.remove();
        }
        // Bookkeeping is updated first so a failing unLock cannot leave a stale entry.
        lockCache.unLock(key);
    }

    /** Records that the current thread acquired {@code key} for the given invocation. */
    private void rememberLock(JoinPoint point, String key) {
        List<HeldLock> held = HELD_LOCKS.get();
        if (held == null) {
            held = new ArrayList<>();
            HELD_LOCKS.set(held);
        }
        held.add(new HeldLock(point, key));
    }

    /** Returns the index of the most recent entry recorded for {@code point}, or -1. */
    private static int lastIndexOfOwner(List<HeldLock> held, JoinPoint point) {
        for (int i = held.size() - 1; i >= 0; i--) {
            if (held.get(i).owner == point) {
                return i;
            }
        }
        return -1;
    }

    /** Returns the index of the most recent entry holding {@code key}, or -1. */
    private static int lastIndexOfKey(List<HeldLock> held, String key) {
        for (int i = held.size() - 1; i >= 0; i--) {
            if (held.get(i).key.equals(key)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Looks up the {@link Lock} annotation, preferring the resolved method and falling
     * back to the signature's method. The signature's method can be an interface method
     * that does not carry the annotation itself.
     *
     * @throws IllegalStateException if neither method is annotated
     */
    private Lock findLock(JoinPoint point, Method method) {
        Lock lock = method.getAnnotation(Lock.class);
        if (lock == null) {
            lock = ((MethodSignature) point.getSignature()).getMethod().getAnnotation(Lock.class);
        }
        if (lock == null) {
            throw new IllegalStateException("No @Lock annotation found on method " + method);
        }
        return lock;
    }

    /**
     * Builds the unprefixed key: the method name followed by the value of every
     * non-empty {@code lockKey} SpEL expression, joined with {@code "-"}.
     *
     * <p>NOTE(review): only the method name is used, not the declaring class, so equally
     * named methods with equal expression values share a key - confirm this is intended.
     */
    private String getKeyName(Method method, Object[] args, Lock lock) {
        String[] expressions = lock.lockKey();
        List<String> parts = new ArrayList<>(expressions.length);
        // One context serves all expressions; it exposes the method arguments as variables.
        EvaluationContext context = new MethodBasedEvaluationContext(null, method, args, nameDiscoverer);
        for (String expression : expressions) {
            if (!ObjectUtils.isEmpty(expression)) {
                Object value = parser.parseExpression(expression).getValue(context);
                parts.add(ObjectUtils.nullSafeToString(value));
            }
        }
        return method.getName() + "-" + StringUtils.join(parts, "-");
    }

    /**
     * Resolves the method to work with. When the signature refers to an interface
     * method, the method with the same name and parameter types declared on the
     * target's class is used instead, if there is one.
     */
    private Method getMethod(JoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        Object target = joinPoint.getTarget();
        if (target != null && method.getDeclaringClass().isInterface()) {
            try {
                return target.getClass().getDeclaredMethod(signature.getName(), method.getParameterTypes());
            } catch (NoSuchMethodException | SecurityException ignored) {
                // Best effort: the target class does not declare the method itself
                // (or it is not accessible), so keep using the interface method.
            }
        }
        return method;
    }

    /** A lock acquired by the current thread together with the invocation that took it. */
    private static final class HeldLock {

        private final JoinPoint owner;

        private final String key;

        private HeldLock(JoinPoint owner, String key) {
            this.owner = owner;
            this.key = key;
        }
    }
}
